#!/usr/bin/env python

"""
This script combines the outputs from the `bit-kraken2-to-taxon-summaries` program. 
"""

import sys
import argparse
import pandas as pd
import os

# Lineage ranks stored for each taxid. The seven columns that follow the first
# (taxid) column of every input table are taken, by position, as these ranks.
_RANKS = ("domain", "phylum", "class", "order", "family", "genus", "species")

# Columns that are pulled out of every input table by name.
_REQUIRED_COLUMNS = ("taxid", "read_counts", "percent_of_reads")


def _fail(*messages):
    """Print each message (to stdout, as this program always has) and exit with status 1."""
    for message in messages:
        print(message)
    sys.exit(1)


def _build_parser():
    """Return the argument parser describing this program's command-line interface."""
    parser = argparse.ArgumentParser(
        description="This script combines the outputs from the `bit-kraken2-to-taxon-summaries` program. "
                    "For version info, run `bit-version`.")

    required = parser.add_argument_group('required arguments')

    required.add_argument("-i", "--input_files", nargs="+", type=str, help="space-delimited list of `bit-kraken2-to-taxon-summaries` output files, can be provided with shell wildcards", action="store", dest="input_files", required=True)
    parser.add_argument("-n", "--sample_names", help='Sample names provided as a comma-delimited list, be sure it matches the order of the input files (by default will use basename of input files up to last period)', action="store", default='', dest="sample_names")
    parser.add_argument("-o", "--output_file", help='Output combined summaries (default: "combined-kraken2-taxon-summaries.tsv")', action="store", dest="output_file", default="combined-kraken2-taxon-summaries.tsv")

    return parser


def _pair_files_with_samples(input_files, sample_names):
    """Return a list of (input file, sample name) pairs, in input order.

    `sample_names` is the raw comma-delimited string from the command line; when
    it is empty, each sample is named after its file's basename up to the last
    period. Exits with status 1 if the number of names doesn't match the number
    of files, if a file is listed twice, or if a sample name would be used twice.
    """
    if sample_names:
        names = sample_names.split(",")

        # checking the number of provided sample names equals the number of input files
        if len(names) != len(input_files):
            _fail("\n    It seems the number of provided sample names doesn't match the number of provided input files :(",
                  "\n    Check usage with `bit-combine-kraken2-taxon-summaries -h`.\n")
    else:
        names = [os.path.basename(path).rsplit('.', 1)[0] for path in input_files]

    # a repeated file is rejected whether or not sample names were provided
    seen_files = set()
    for path in input_files:
        if path in seen_files:
            _fail('\n    It seems the file "' + path + '" is trying to get in here twice.',
                  "\n    That's not gonna fly :(\n")
        seen_files.add(path)

    # repeated sample names would produce clashing output column names
    seen_names = set()
    for name in names:
        if name in seen_names:
            _fail('\n    It seems the sample name "' + name + '" would be used for more than one input file.',
                  "\n    Sample names need to be unique :(\n")
        seen_names.add(name)

    return list(zip(input_files, names))


def _combine_tables(samples):
    """Read every (file, sample name) pair and return the combined table.

    The result has one row per taxid, the lineage columns recorded the first
    time that taxid was seen, and a `<sample>_read_counts` and
    `<sample>_perc_of_reads` column for each sample. Taxids absent from a sample
    get 0 in that sample's columns; missing lineage values become "NA".
    """
    # keeping dictionary of all taxids (keys) and full lineages (values)
    lineages = {}

    # the first sample's table seeds the merge, which keeps the taxid column's
    # dtype as read from the files rather than forcing it to object
    combined = None

    for path, sample in samples:

        # reading current file into pandas dataframe
        curr_tab = pd.read_csv(path, sep="\t")

        missing = [column for column in _REQUIRED_COLUMNS if column not in curr_tab.columns]
        if missing or curr_tab.shape[1] < 1 + len(_RANKS):
            _fail('\n    The file "' + path + '" doesn\'t look like `bit-kraken2-to-taxon-summaries` output :(',
                  "\n    Check usage with `bit-combine-kraken2-taxon-summaries -h`.\n")

        # adding to building taxid dictionary (first column is the key, the
        # next seven are the ranks); the first lineage seen for a taxid is kept
        for row in curr_tab.iloc[:, :1 + len(_RANKS)].itertuples(index=False, name=None):
            if row[0] not in lineages:
                lineages[row[0]] = dict(zip(_RANKS, row[1:]))

        # trimming down current table, and changing count and percent column
        # names to match sample ID
        curr_sub_tab = curr_tab[list(_REQUIRED_COLUMNS)].rename(columns={
            "read_counts": str(sample) + "_read_counts",
            "percent_of_reads": str(sample) + "_perc_of_reads",
        })

        # merging with the table being built on taxid
        if combined is None:
            combined = curr_sub_tab
        else:
            combined = combined.merge(curr_sub_tab, on="taxid", how="outer")

    ## replacing NAs with 0s
    # NOTE(review): the outer merges can upcast integer counts to float where a
    # taxid is absent from a sample (so 5 is written as 5.0); left as-is to keep
    # the output format unchanged -- confirm whether that is wanted.
    combined = combined.fillna(0)

    ## making taxid dictionary into dataframe, moving index to a "taxid" column
    lineage_tab = pd.DataFrame.from_dict(lineages, orient="index")
    lineage_tab = lineage_tab.rename_axis("taxid").reset_index()

    # merging, sorting, and changing remaining NAs (missing ranks) to "NA"
    final_tab = lineage_tab.merge(combined, on="taxid", how="outer")
    final_tab = final_tab.sort_values(by=["taxid"])

    return final_tab.fillna("NA")


def main(argv=None):
    """Combine `bit-kraken2-to-taxon-summaries` outputs into one tab-delimited table.

    Args:
        argv: command-line arguments without the program name; defaults to
            `sys.argv[1:]`.

    Raises:
        SystemExit: status 0 after printing help when no arguments are given,
            status 1 on invalid input, and whatever argparse raises on a bad
            command line.
    """
    if argv is None:
        argv = sys.argv[1:]

    parser = _build_parser()

    if len(argv) == 0:
        parser.print_help(sys.stderr)
        sys.exit(0)

    args = parser.parse_args(argv)

    samples = _pair_files_with_samples(args.input_files, args.sample_names)
    final_tab = _combine_tables(samples)

    ## writing out (pandas opens the file itself, so line endings aren't translated twice)
    final_tab.to_csv(args.output_file, index=False, sep="\t")


if __name__ == "__main__":
    main()

